// Copyright (C) 2024 Kumo inc.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#include <kllm/core/dump.h>
#include <kllm/core/params.h>
#include <regex>
#include <thread>

namespace kllm {


    // Emit a vector of floats as a YAML flow sequence: "name: [a, b, c]".
    // An empty vector is written as a bare "name:" line (YAML null).
    void yaml_dump_vector_float(FILE * stream, const char * prop_name, const std::vector<float> & data) {
        if (data.empty()) {
            fprintf(stream, "%s:\n", prop_name);
            return;
        }

        fprintf(stream, "%s: [", prop_name);
        const char * sep = "";
        for (const float value : data) {
            fprintf(stream, "%s%e", sep, value);
            sep = ", ";
        }
        fprintf(stream, "]\n");
    }

    // Emit a vector of ints as a YAML flow sequence: "name: [1, 2, 3]".
    // An empty vector is written as a bare "name:" line (YAML null).
    void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector<int> & data) {
        if (data.empty()) {
            fprintf(stream, "%s:\n", prop_name);
            return;
        }

        // Print the first element with the opening bracket, then the rest
        // each prefixed by ", " — same byte output as the size-1 loop form.
        fprintf(stream, "%s: [%d", prop_name, data.front());
        for (auto it = data.begin() + 1; it != data.end(); ++it) {
            fprintf(stream, ", %d", *it);
        }
        fprintf(stream, "]\n");
    }

    // Emit a string property in YAML, picking the representation by content:
    //  - NULL/empty                    -> bare "name:" line
    //  - leading/trailing whitespace   -> double-quoted scalar with escapes
    //    (plain/literal YAML scalars would not preserve the whitespace)
    //  - single line                   -> plain scalar "name: value"
    //  - multi-line                    -> literal block scalar "name: |"
    void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) {
        std::string data_str(data == NULL ? "" : data);

        if (data_str.empty()) {
            fprintf(stream, "%s:\n", prop_name);
            return;
        }

        size_t pos_start = 0;
        size_t pos_found = 0;

        // Fix: std::isspace has undefined behavior for negative values other
        // than EOF — on platforms with signed char, any byte >= 0x80 (e.g.
        // UTF-8 continuation bytes) is negative. Cast to unsigned char first.
        if (std::isspace(static_cast<unsigned char>(data_str[0])) ||
            std::isspace(static_cast<unsigned char>(data_str.back()))) {
            // Quote the value so the leading/trailing whitespace survives.
            // Replacement order matters: escape newlines, then quotes, then
            // any backslash that is not already part of a \n or \" sequence.
            data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
            data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
            data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
            data_str = "\"" + data_str + "\"";
            fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
            return;
        }

        if (data_str.find('\n') == std::string::npos) {
            fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
            return;
        }

        // Literal block scalar: each line indented by two spaces.
        fprintf(stream, "%s: |\n", prop_name);
        while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
            fprintf(stream, "  %s\n", data_str.substr(pos_start, pos_found - pos_start).c_str());
            pos_start = pos_found + 1;
        }
        // Fix: the original loop silently dropped the text after the final
        // newline — and a string *ending* in '\n' takes the quoted branch
        // above, so the last segment was always lost (e.g. "a\nb" lost "b").
        if (pos_start < data_str.size()) {
            fprintf(stream, "  %s\n", data_str.substr(pos_start).c_str());
        }
    }

    void yaml_dump_non_result_info(FILE * stream, const KMParams & params, const llama_context * lctx,
                                   const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
        ggml_cpu_init(); // some ARM features are detected at runtime

        const auto & sparams = params.sparams;

        fprintf(stream, "cpu_has_arm_fma: %s\n",     ggml_cpu_has_arm_fma()     ? "true" : "false");
        fprintf(stream, "cpu_has_avx: %s\n",         ggml_cpu_has_avx()         ? "true" : "false");
        fprintf(stream, "cpu_has_avx_vnni: %s\n",    ggml_cpu_has_avx_vnni()    ? "true" : "false");
        fprintf(stream, "cpu_has_avx2: %s\n",        ggml_cpu_has_avx2()        ? "true" : "false");
        fprintf(stream, "cpu_has_avx512: %s\n",      ggml_cpu_has_avx512()      ? "true" : "false");
        fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
        fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
        fprintf(stream, "cpu_has_cuda: %s\n",        ggml_cpu_has_cuda()        ? "true" : "false");
        fprintf(stream, "cpu_has_vulkan: %s\n",      ggml_cpu_has_vulkan()      ? "true" : "false");
        fprintf(stream, "cpu_has_kompute: %s\n",     ggml_cpu_has_kompute()     ? "true" : "false");
        fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
        fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
        fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
        fprintf(stream, "cpu_has_sve: %s\n",         ggml_cpu_has_sve()         ? "true" : "false");
        fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
        fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
        fprintf(stream, "cpu_has_riscv_v: %s\n",     ggml_cpu_has_riscv_v()     ? "true" : "false");
        fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");
        fprintf(stream, "cpu_has_blas: %s\n",        ggml_cpu_has_blas()        ? "true" : "false");
        fprintf(stream, "cpu_has_sse3: %s\n",        ggml_cpu_has_sse3()        ? "true" : "false");
        fprintf(stream, "cpu_has_vsx: %s\n",         ggml_cpu_has_vsx()         ? "true" : "false");
        fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false");

#ifdef NDEBUG
        fprintf(stream, "debug: false\n");
#else
        fprintf(stream, "debug: true\n");
#endif // NDEBUG

        fprintf(stream, "model_desc: %s\n", model_desc);
        fprintf(stream, "n_vocab: %d  # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));

#ifdef __OPTIMIZE__
        fprintf(stream, "optimize: true\n");
#else
        fprintf(stream, "optimize: false\n");
#endif // __OPTIMIZE__

        fprintf(stream, "time: %s\n", timestamp.c_str());

        fprintf(stream, "\n");
        fprintf(stream, "###############\n");
        fprintf(stream, "# User Inputs #\n");
        fprintf(stream, "###############\n");
        fprintf(stream, "\n");

        fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
        fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
        fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
        fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
        fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
        fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
        fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
        fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
        fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
        fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
        fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
        fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
        yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str());
        fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
        fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
        fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
        fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");

        yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
        fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
        yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str());
        fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false");
        fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false");
        fprintf(stream, "keep: %d # default: 0\n", params.n_keep);
        fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());

        fprintf(stream, "logit_bias:\n");
        for (const auto & logit_bias : sparams.logit_bias) {
            fprintf(stream, "  %d: %f", logit_bias.token, logit_bias.bias);
        }

        fprintf(stream, "lora:\n");
        for (auto & la : params.lora_adapters) {
            if (la.scale == 1.0f) {
                fprintf(stream, "  - %s\n", la.path.c_str());
            }
        }
        fprintf(stream, "lora_scaled:\n");
        for (auto & la : params.lora_adapters) {
            if (la.scale != 1.0f) {
                fprintf(stream, "  - %s: %f\n", la.path.c_str(), la.scale);
            }
        }
        fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
        fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
        fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
        fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
        fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
        fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
        fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
        fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
        fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
        fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
        fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
        fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
        fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
        fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
        fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
        fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
        fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
        fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
        yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str());
        fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
        fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
        fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
        yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens);
        fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

        fprintf(stream, "reverse_prompt:\n");
        for (std::string ap : params.antiprompt) {
            size_t pos = 0;
            while ((pos = ap.find('\n', pos)) != std::string::npos) {
                ap.replace(pos, 1, "\\n");
                pos += 1;
            }

            fprintf(stream, "  - %s\n", ap.c_str());
        }

        fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
        fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
        fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
        fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
        fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
        fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

        const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
        yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);

        fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
        fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
        fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
        fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
        fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
        fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
        fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
        fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
        fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
    }

}  // namespace kllm
